// Package ipsec provides primitives for establishing IPsec in the fastdp mode.
package ipsec

import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"strconv"
	"sync"
	"syscall"

	"github.com/coreos/go-iptables/iptables"
	"github.com/pkg/errors"
	"github.com/sirupsen/logrus"
	"github.com/vishvananda/netlink"
	"golang.org/x/crypto/hkdf"

	"github.com/weaveworks/mesh"
)

const (
	// Sizes of the key material handled by deriveKey / genNonce.
	keySize   = 36 // AES-GCM key 32 bytes + 4 bytes salt
	nonceSize = 32 // HKDF nonce size

	// mark is the packet mark bit shared by the iptables MARK rules and the
	// xfrm policy; markStr is the same bit in iptables "value/mask" notation.
	mark    = uint32(0x1) << 17 // iptables marks
	markStr = "0x20000/0x20000" // update if the above mark changes

	// iptables tables and chains managed by this package (see resetIPTables).
	tableMangle  = "mangle"
	tableFilter  = "filter"
	chainIn      = "WEAVE-IPSEC-IN"
	chainInMark  = "WEAVE-IPSEC-IN-MARK"
	chainOut     = "WEAVE-IPSEC-OUT"
	chainOutMark = "WEAVE-IPSEC-OUT-MARK"
)

// SPI is the 32-bit Security Parameter Index of an xfrm state (SA). It is the
// value matched by the "--espspi" iptables rule and carried in msgInitSARemote.
type SPI uint32

// spiID is the 24-byte key produced by getSPIId from
// (source peer, destination peer, connection UID).
//
// Used to identify:
// - directional SPIs,
// - ipsec establishments.
type spiID [24]byte

// getSPIId packs srcPeer, dstPeer and connUID, in that order, as three
// big-endian 64-bit words into a spiID. Because the peers are ordered, the
// two directions of one connection get distinct IDs.
func getSPIId(srcPeer, dstPeer mesh.PeerName, connUID uint64) (id spiID) {
	words := [...]uint64{uint64(srcPeer), uint64(dstPeer), connUID}
	for i, w := range words {
		binary.BigEndian.PutUint64(id[8*i:], w)
	}
	return id
}

// spiInfo is the bookkeeping record kept for every SA installed by
// InitSALocal (inbound) or InitSARemote (outbound).
type spiInfo struct {
	spi      SPI  // SPI of the installed xfrm state
	isDirOut bool // true for outbound SAs (InitSARemote), false for inbound (InitSALocal)
}

// IPSec

// IPSec holds the iptables handle, the logger and the in-memory bookkeeping
// of the SAs this process has installed. The embedded RWMutex guards the two
// maps; InitSALocal, InitSARemote, Destroy and Flush all take the write lock.
type IPSec struct {
	sync.RWMutex
	ipt *iptables.IPTables
	log *logrus.Logger

	// Per-direction record, keyed by getSPIId(src, dst, connUID).
	spiInfo map[spiID]spiInfo
	// A reference to spiInfo; spiInfo might be of an expired SPI.
	// The pointer refers to a copy of the record stored in spiInfo above,
	// not to the map entry itself.
	spis map[SPI]*spiInfo
}

// New returns an IPSec with empty SPI bookkeeping that logs to log.
// It fails if the iptables handle cannot be created.
func New(log *logrus.Logger) (*IPSec, error) {
	handle, err := iptables.New()
	if err != nil {
		return nil, errors.Wrap(err, "iptables new")
	}

	return &IPSec{
		ipt:     handle,
		log:     log,
		spiInfo: map[spiID]spiInfo{},
		spis:    map[SPI]*spiInfo{},
	}, nil
}

// InitSALocal sets up the inbound (remotePeer -> localPeer) half of an ipsec
// establishment and returns the serialized msgInitSARemote that the caller
// must deliver to remotePeer so it can set up the matching outbound half.
//
// Steps, all performed under the write lock:
//  1. generate a fresh nonce and derive the SA key from sessionKey, the nonce
//     and localPeer;
//  2. let the kernel allocate an SPI for (remoteIP -> localIP);
//  3. fill in the allocated SA with the derived key;
//  4. install the iptables rules that drop non-encrypted traffic;
//  5. record the SPI under the inbound spiID.
//
// On any failure a wrapped error is returned and nothing is recorded.
func (ipsec *IPSec) InitSALocal(localPeer, remotePeer mesh.PeerName, connUID uint64, localIP, remoteIP net.IP, udpPort int, sessionKey *[32]byte) ([]byte, error) {
	// Key under which the inbound SPI is remembered.
	inID := getSPIId(remotePeer, localPeer, connUID)

	ipsec.Lock()
	defer ipsec.Unlock()

	// Key material for the SA.
	nonce, err := genNonce()
	if err != nil {
		return nil, errors.Wrap(err, "generate nonce")
	}
	key, err := deriveKey(sessionKey[:], nonce, localPeer)
	if err != nil {
		return nil, errors.Wrap(err, "derive key")
	}

	// Ask the kernel for an SPI; it guarantees {dstIP, spi} is unique.
	allocated, err := netlink.XfrmStateAllocSpi(xfrmAllocSpiState(remoteIP, localIP))
	if err != nil {
		return nil, errors.Wrap(err, fmt.Sprintf("ip xfrm state allocspi (in, %s, %s)", remoteIP, localIP))
	}
	spi := SPI(allocated.Spi)

	ipsec.log.Infof("ipsec: InitSALocal: %s -> %s :%d 0x%x", remoteIP, localIP, udpPort, spi)

	// Turn the allocated state into a complete SA.
	state, err := xfrmState(remoteIP, localIP, spi, false, key)
	if err != nil {
		return nil, errors.Wrap(err, "new xfrm state (in)")
	}
	if err := netlink.XfrmStateUpdate(state); err != nil {
		return nil, errors.Wrap(err, fmt.Sprintf("xfrm state update (in, %s, %s, 0x%x)", state.Src, state.Dst, state.Spi))
	}

	// Protect the tunnel port against non-encrypted traffic.
	if err := ipsec.installDropNonEncrypted(localIP, remoteIP, udpPort, spi); err != nil {
		return nil, errors.Wrap(err, fmt.Sprintf("install protecting rules (%s, %s, %d, 0x%x)", localIP, remoteIP, udpPort, spi))
	}

	// Bookkeeping for Destroy / Flush.
	record := spiInfo{spi: spi, isDirOut: false}
	ipsec.spiInfo[inID] = record
	ipsec.spis[spi] = &record

	// Message that triggers InitSARemote on the other side.
	return (&msgInitSARemote{nonce, spi}).serialize(), nil
}

// InitSARemote initializes outbound ipsec to remotePeer.
// Triggered by remotePeer.
//
// msgInitSARemote is the payload produced by remotePeer's InitSALocal; it
// carries the nonce and the SPI allocated on the remote side. The SA key is
// derived from sessionKey, that nonce and remotePeer, i.e. with the same
// inputs the remote side used. The function then adds the outbound SA,
// creates or updates the outbound policy and records the SPI under the
// outbound spiID.
//
// If installing the policy fails, the SA added by this call is removed again
// (best effort) so that it does not linger untracked in the kernel.
func (ipsec *IPSec) InitSARemote(msgInitSARemote []byte, localPeer, remotePeer mesh.PeerName, connUID uint64, localIP, remoteIP net.IP, udpPort int, sessionKey *[32]byte) error {
	// ID of outbound SPI
	spiID := getSPIId(localPeer, remotePeer, connUID)

	msg, err := deserializeMsgInitSARemote(msgInitSARemote)
	if err != nil {
		return errors.Wrap(err, "deserialize InitSARemote")
	}
	spi := msg.spi

	ipsec.Lock()
	defer ipsec.Unlock()

	ipsec.log.Infof("ipsec: InitSARemote: %s -> %s :%d 0x%x", localIP, remoteIP, udpPort, spi)

	// Derive SA key by using the received nonce
	key, err := deriveKey(sessionKey[:], msg.nonce, remotePeer)
	if err != nil {
		return errors.Wrap(err, "derive key")
	}

	// Create SA
	sa, err := xfrmState(localIP, remoteIP, spi, true, key)
	if err != nil {
		return errors.Wrap(err, "new xfrm state (out)")
	}
	if err := netlink.XfrmStateAdd(sa); err != nil {
		return errors.Wrap(err, fmt.Sprintf("xfrm state add (out, %s, %s, 0x%x)", sa.Src, sa.Dst, sa.Spi))
	}

	// Create or update SP
	sp := xfrmPolicy(localIP, remoteIP, spi)
	if err := netlink.XfrmPolicyUpdate(sp); err != nil {
		// The SA is not recorded yet, so neither Destroy nor Flush would
		// ever remove it; undo the add here.
		if delErr := netlink.XfrmStateDel(sa); delErr != nil {
			ipsec.log.Warnf("ipsec: xfrm state del (out, %s, %s, 0x%x) failed: %s", sa.Src, sa.Dst, sa.Spi, delErr)
		}
		return errors.Wrap(err, fmt.Sprintf("xfrm policy update (%s, %s, 0x%x)", localIP, remoteIP, spi))
	}

	si := spiInfo{spi: spi, isDirOut: true}
	ipsec.spiInfo[spiID] = si
	ipsec.spis[spi] = &si

	return nil
}

// Destroy tears down whatever ipsec state (inbound and/or outbound) is
// recorded for the connection identified by (localPeer, remotePeer, connUID).
//
// Teardown is best effort: kernel and iptables failures are logged as
// warnings and do not stop the remaining steps; the bookkeeping entries are
// removed regardless. The returned error is always nil.
func (ipsec *IPSec) Destroy(localPeer, remotePeer mesh.PeerName, connUID uint64, localIP, remoteIP net.IP, udpPort int) error {
	outID := getSPIId(localPeer, remotePeer, connUID)
	inID := getSPIId(remotePeer, localPeer, connUID)

	ipsec.Lock()
	defer ipsec.Unlock()

	// Inbound half: SA (remote -> local) plus the protecting iptables rules.
	if info, found := ipsec.spiInfo[inID]; found {
		spi := info.spi
		ipsec.log.Infof("ipsec: destroy: in %s -> %s 0x%x", remoteIP, localIP, spi)

		sa := &netlink.XfrmState{
			Src:   remoteIP,
			Dst:   localIP,
			Proto: netlink.XFRM_PROTO_ESP,
			Spi:   int(spi),
		}
		if err := netlink.XfrmStateDel(sa); err != nil {
			ipsec.log.Warnf("ipsec: xfrm state del (in, %s, %s, 0x%x) failed: %s", sa.Src, sa.Dst, sa.Spi, err)
		}

		if err := ipsec.removeDropNonEncrypted(localIP, remoteIP, udpPort, spi); err != nil {
			ipsec.log.Warnf("ipsec: remove protecting rules (%s, %s, %d, 0x%x) failed: %s", localIP, remoteIP, udpPort, spi, err)
		}

		delete(ipsec.spiInfo, inID)
		delete(ipsec.spis, spi)
	}

	// Outbound half: policy (only if it still points at our SPI) plus SA
	// (local -> remote).
	if info, found := ipsec.spiInfo[outID]; found {
		spi := info.spi
		ipsec.log.Infof("ipsec: destroy: out %s -> %s 0x%x", localIP, remoteIP, spi)

		installed, err := netlink.XfrmPolicyGet(xfrmPolicy(localIP, remoteIP, spi))
		switch {
		case err != nil:
			ipsec.log.Warnf("ipsec: xfrm policy get (%s, %s, 0x%x) failed: %s", localIP, remoteIP, spi, err)
		case len(installed.Tmpls) != 1:
			// Not a policy of the shape we install; leave it alone.
		case installed.Tmpls[0].Spi != int(spi):
			// The policy has been updated to reference another SPI.
			ipsec.log.Debugf("ipsec: xfrm not my policy (%s, %s, 0x%x) got 0x%x ", localIP, remoteIP, spi, installed.Tmpls[0].Spi)
		default:
			if err := netlink.XfrmPolicyDel(xfrmPolicy(localIP, remoteIP, spi)); err != nil {
				ipsec.log.Warnf("ipsec: xfrm policy del (%s, %s, 0x%x) failed: %s", localIP, remoteIP, spi, err)
			}
		}

		sa := &netlink.XfrmState{
			Src:   localIP,
			Dst:   remoteIP,
			Proto: netlink.XFRM_PROTO_ESP,
			Spi:   int(spi),
		}
		if err := netlink.XfrmStateDel(sa); err != nil {
			ipsec.log.Warnf("ipsec: xfrm state del (out, %s, %s, 0x%x) failed: %s", sa.Src, sa.Dst, sa.Spi, err)
		}

		delete(ipsec.spiInfo, outID)
		delete(ipsec.spis, spi)
	}

	return nil
}

// Flush removes all policies/SAs established by us. Also, it removes chains and
// rules of iptables.
//
// A policy is considered ours when it carries our mark and has at least one
// template; an SA is considered ours when its SPI is present in ipsec.spis.
// The first failure aborts the flush and is returned wrapped.
//
// If destroy is true, the chains and the rules won't be re-created.
func (ipsec *IPSec) Flush(destroy bool) error {
	ipsec.Lock()
	defer ipsec.Unlock()

	policies, err := netlink.XfrmPolicyList(syscall.AF_INET)
	if err != nil {
		return errors.Wrap(err, "xfrm policy list")
	}
	// Iterate by index so the (large) netlink structs are not copied.
	for i := range policies {
		p := &policies[i]
		if p.Mark == nil || p.Mark.Value != mark || len(p.Tmpls) == 0 {
			continue
		}
		if err := netlink.XfrmPolicyDel(p); err != nil {
			spi := SPI(p.Tmpls[0].Spi)
			return errors.Wrap(err, fmt.Sprintf("xfrm policy del (%s, %s, 0x%x)", p.Src, p.Dst, spi))
		}
	}

	states, err := netlink.XfrmStateList(syscall.AF_INET)
	if err != nil {
		return errors.Wrap(err, "xfrm state list")
	}
	for i := range states {
		s := &states[i]
		if _, ok := ipsec.spis[SPI(s.Spi)]; !ok {
			continue
		}
		if err := netlink.XfrmStateDel(s); err != nil {
			// The failing operation is the delete, not the list.
			return errors.Wrap(err, fmt.Sprintf("xfrm state del (%s, %s, 0x%x)", s.Src, s.Dst, s.Spi))
		}
	}

	if err := ipsec.resetIPTables(destroy); err != nil {
		return errors.Wrap(err, "reset ip tables")
	}

	return nil
}

// iptables

// chain names an iptables chain together with the table it lives in.
type chain struct {
	table string // iptables table, e.g. tableMangle or tableFilter
	chain string // chain name within that table
}
// rule describes a single iptables rule: where it goes and its rulespec.
type rule struct {
	table    string   // iptables table
	chain    string   // chain the rule is appended to / deleted from
	rulespec []string // rule arguments as passed to iptables
	// unique selects AppendUnique instead of Append in
	// installDropNonEncrypted; resetRules does not look at it.
	unique bool
}

// clearChains calls ClearChain on every listed chain, in order, and stops at
// the first failure, which is returned wrapped with the table and chain name.
func (ipsec *IPSec) clearChains(chains []chain) error {
	for i := range chains {
		table, name := chains[i].table, chains[i].chain
		if err := ipsec.ipt.ClearChain(table, name); err != nil {
			return errors.Wrap(err, fmt.Sprintf("iptables clear chain (%s, %s)", table, name))
		}
	}
	return nil
}

// deleteChains calls DeleteChain on every listed chain, in order, and stops
// at the first failure, which is returned wrapped with the table and chain
// name.
func (ipsec *IPSec) deleteChains(chains []chain) error {
	for i := range chains {
		table, name := chains[i].table, chains[i].chain
		if err := ipsec.ipt.DeleteChain(table, name); err != nil {
			return errors.Wrap(err, fmt.Sprintf("iptables delete chain (%s, %s)", table, name))
		}
	}
	return nil
}

// resetRules brings every rule into the state requested by destroy:
//
//   - destroy == false: a rule that does not exist yet is appended;
//   - destroy == true:  a rule that exists is deleted.
//
// Rules already in the requested state are left untouched. The first iptables
// failure is returned wrapped.
func (ipsec *IPSec) resetRules(rules []rule, destroy bool) error {
	for _, r := range rules {
		present, err := ipsec.ipt.Exists(r.table, r.chain, r.rulespec...)
		if err != nil {
			return errors.Wrap(err, fmt.Sprintf("iptables exists rule (%s, %s, %s)", r.table, r.chain, r.rulespec))
		}

		if destroy {
			if !present {
				continue
			}
			if err := ipsec.ipt.Delete(r.table, r.chain, r.rulespec...); err != nil {
				return errors.Wrap(err, fmt.Sprintf("iptables delete rule (%s, %s, %s)", r.table, r.chain, r.rulespec))
			}
			continue
		}

		if present {
			continue
		}
		if err := ipsec.ipt.Append(r.table, r.chain, r.rulespec...); err != nil {
			return errors.Wrap(err, fmt.Sprintf("iptables append rule (%s, %s, %s)", r.table, r.chain, r.rulespec))
		}
	}
	return nil
}

// resetIPTables resets the package's static iptables setup.
//
// It first clears all of our chains, then runs the static rules (jumps from
// INPUT/OUTPUT into our chains, the MARK rules, and the filter/OUTPUT DROP
// rule) through resetRules with the given destroy flag. When destroy is true
// the chains themselves are deleted afterwards.
func (ipsec *IPSec) resetIPTables(destroy bool) error {
	ownChains := []chain{
		{table: tableMangle, chain: chainIn},
		{table: tableMangle, chain: chainInMark},
		{table: tableFilter, chain: chainIn},
		{table: tableMangle, chain: chainOut},
		{table: tableMangle, chain: chainOutMark},
	}

	// Drops marked, non-ESP outbound packets that no xfrm policy applies to.
	dropUnprotected := []string{
		"!", "-p", "esp",
		"-m", "policy", "--dir", "out", "--pol", "none",
		"-m", "mark", "--mark", markStr,
		"-j", "DROP",
	}

	staticRules := []rule{
		{table: tableMangle, chain: "INPUT", rulespec: []string{"-j", chainIn}, unique: true},
		{table: tableMangle, chain: chainInMark, rulespec: []string{"-j", "MARK", "--set-xmark", markStr}, unique: true},
		{table: tableFilter, chain: "INPUT", rulespec: []string{"-j", chainIn}, unique: true},
		{table: tableMangle, chain: "OUTPUT", rulespec: []string{"-j", chainOut}, unique: true},
		{table: tableMangle, chain: chainOutMark, rulespec: []string{"-j", "MARK", "--set-xmark", markStr}, unique: true},
		{table: tableFilter, chain: "OUTPUT", rulespec: dropUnprotected, unique: true},
	}

	if err := ipsec.clearChains(ownChains); err != nil {
		return err
	}
	if err := ipsec.resetRules(staticRules, destroy); err != nil {
		return err
	}
	if !destroy {
		return nil
	}
	return ipsec.deleteChains(ownChains)
}

// ruleMarkInboundESP builds the mangle/chainIn rule that sends ESP packets
// carrying inSPI to chainInMark.
//
// The rule matches packets travelling from dstIP to srcIP, i.e. the reverse
// of the parameter names; InitSALocal and Destroy pass the local address as
// srcIP and the remote one as dstIP.
func ruleMarkInboundESP(srcIP, dstIP net.IP, inSPI SPI) rule {
	spiHex := "0x" + strconv.FormatUint(uint64(inSPI), 16)
	from, to := dstIP.String(), srcIP.String()

	return rule{
		table: tableMangle,
		chain: chainIn,
		rulespec: []string{
			"-s", from, "-d", to,
			"-p", "esp",
			"-m", "esp", "--espspi", spiHex,
			"-j", chainInMark,
		},
		unique: true,
	}
}

// rulesDropNonEncrypted returns the per-connection rules installed by
// installDropNonEncrypted and removed by removeDropNonEncrypted:
//
//  1. the inbound ESP marking rule from ruleMarkInboundESP;
//  2. a filter/chainIn rule dropping UDP packets dstIP -> srcIP to udpPort
//     that do not carry the mark;
//  3. a mangle/chainOut rule sending UDP packets srcIP -> dstIP to udpPort
//     to chainOutMark.
//
// Only the first rule is flagged unique.
func rulesDropNonEncrypted(srcIP, dstIP net.IP, udpPort int, inSPI SPI) []rule {
	port := strconv.FormatUint(uint64(udpPort), 10)
	src, dst := srcIP.String(), dstIP.String()

	dropUnmarkedInbound := rule{
		table: tableFilter,
		chain: chainIn,
		rulespec: []string{
			"-s", dst, "-d", src,
			"-p", "udp", "--dport", port,
			"-m", "mark", "!", "--mark", markStr,
			"-j", "DROP",
		},
		unique: false,
	}
	markOutbound := rule{
		table: tableMangle,
		chain: chainOut,
		rulespec: []string{
			"-s", src, "-d", dst,
			"-p", "udp", "--dport", port,
			"-j", chainOutMark,
		},
		unique: false,
	}

	return []rule{
		ruleMarkInboundESP(srcIP, dstIP, inSPI),
		dropUnmarkedInbound,
		markOutbound,
	}
}

// installDropNonEncrypted appends the rules from rulesDropNonEncrypted, in
// order. Rules flagged unique go through AppendUnique, the others through
// Append. The first failure is returned wrapped; rules appended before it
// are left in place.
func (ipsec *IPSec) installDropNonEncrypted(srcIP, dstIP net.IP, udpPort int, inSPI SPI) error {
	for _, r := range rulesDropNonEncrypted(srcIP, dstIP, udpPort, inSPI) {
		var err error
		if r.unique {
			err = ipsec.ipt.AppendUnique(r.table, r.chain, r.rulespec...)
		} else {
			err = ipsec.ipt.Append(r.table, r.chain, r.rulespec...)
		}
		if err != nil {
			return errors.Wrap(err, fmt.Sprintf("iptables append (%s, %s, %s)", r.table, r.chain, r.rulespec))
		}
	}
	return nil
}

// removeDropNonEncrypted deletes, via resetRules in destroy mode, those rules
// from rulesDropNonEncrypted that currently exist.
func (ipsec *IPSec) removeDropNonEncrypted(srcIP, dstIP net.IP, udpPort int, inSPI SPI) error {
	return ipsec.resetRules(rulesDropNonEncrypted(srcIP, dstIP, udpPort, inSPI), true)
}

// xfrm

// xfrmAllocSpiState returns the SA template shared by SPI allocation and by
// xfrmState: ESP in transport mode between srcIP and dstIP, with a replay
// window of 256 and ESN enabled. SPI and key material are left unset.
func xfrmAllocSpiState(srcIP, dstIP net.IP) *netlink.XfrmState {
	var sa netlink.XfrmState
	sa.Src, sa.Dst = srcIP, dstIP
	sa.Proto = netlink.XFRM_PROTO_ESP
	sa.Mode = netlink.XFRM_MODE_TRANSPORT
	sa.ReplayWindow = 256
	sa.ESN = true
	return &sa
}

// xfrmState returns a complete SA description for srcIP -> dstIP: the
// template from xfrmAllocSpiState with the given SPI and an
// rfc4106(gcm(aes)) AEAD using key and an ICV length of 128.
//
// key must be exactly keySize bytes long, otherwise an error is returned.
// isDirOut is currently not used by the function body.
func xfrmState(srcIP, dstIP net.IP, spi SPI, isDirOut bool, key []byte) (*netlink.XfrmState, error) {
	if len(key) != keySize {
		return nil, fmt.Errorf("key should be %d bytes long", keySize)
	}

	aead := &netlink.XfrmStateAlgo{
		Name:   "rfc4106(gcm(aes))",
		Key:    key,
		ICVLen: 128,
	}

	sa := xfrmAllocSpiState(srcIP, dstIP)
	sa.Spi, sa.Aead = int(spi), aead
	return sa, nil
}

// xfrmPolicy returns the outbound policy for UDP traffic from srcIP/32 to
// dstIP/32 that carries our mark. Its single template requests ESP in
// transport mode with the given SPI.
func xfrmPolicy(srcIP, dstIP net.IP, spi SPI) *netlink.XfrmPolicy {
	hostMask := net.CIDRMask(32, 32) // /32

	tmpl := netlink.XfrmPolicyTmpl{
		Src:   srcIP,
		Dst:   dstIP,
		Proto: netlink.XFRM_PROTO_ESP,
		Mode:  netlink.XFRM_MODE_TRANSPORT,
		Spi:   int(spi),
	}

	policy := &netlink.XfrmPolicy{
		Src:   &net.IPNet{IP: srcIP, Mask: hostMask},
		Dst:   &net.IPNet{IP: dstIP, Mask: hostMask},
		Proto: syscall.IPPROTO_UDP,
		Dir:   netlink.XFRM_DIR_OUT,
		Mark:  &netlink.XfrmMark{Value: mark, Mask: mark},
		Tmpls: []netlink.XfrmPolicyTmpl{tmpl},
	}
	return policy
}

// Key derivation

// genNonce returns nonceSize bytes read from crypto/rand. It fails if the
// read reports an error or yields fewer bytes than requested.
func genNonce() ([]byte, error) {
	nonce := make([]byte, nonceSize)

	switch got, err := rand.Read(nonce); {
	case err != nil:
		return nil, fmt.Errorf("crypto rand failed: %s", err)
	case got != nonceSize:
		return nil, fmt.Errorf("not enough of random data: %d", got)
	}

	return nonce, nil
}

// deriveKey derives keySize bytes of SA key material with HKDF-SHA256, using
// sessionKey as the secret, nonce as the salt and the big-endian encoding of
// peerName as the info parameter.
//
// InitSALocal passes its own peer name and InitSARemote passes the remote
// one, so both ends feed the same peer name into the derivation.
func deriveKey(sessionKey []byte, nonce []byte, peerName mesh.PeerName) ([]byte, error) {
	var info [8]byte
	binary.BigEndian.PutUint64(info[:], uint64(peerName))

	kdf := hkdf.New(sha256.New, sessionKey, nonce, info[:])

	key := make([]byte, keySize)
	switch n, err := io.ReadFull(kdf, key); {
	case err != nil:
		return nil, err
	case n != keySize:
		return nil, fmt.Errorf("derived too short key: %d", n)
	}

	return key, nil
}

// Protocol Messages

// msgInitSARemote is the payload InitSALocal hands back for delivery to the
// remote peer, which feeds it to InitSARemote. It carries the nonce used for
// key derivation and the SPI allocated for the inbound SA.
type msgInitSARemote struct {
	nonce []byte
	spi   SPI
}

// deserializeMsgInitSARemote parses a payload produced by serialize. The
// payload must be exactly size() bytes long. The nonce is copied out of b, so
// the result does not alias the input buffer.
func deserializeMsgInitSARemote(b []byte) (*msgInitSARemote, error) {
	switch {
	case len(b) == 0:
		return nil, fmt.Errorf("empty msg")
	case len(b) != (&msgInitSARemote{}).size():
		return nil, fmt.Errorf("invalid payload size: %d", len(b))
	}

	nonce := make([]byte, nonceSize)
	copy(nonce, b[:nonceSize])

	// The SPI is the big-endian uint32 that immediately follows the nonce.
	spi := SPI(binary.BigEndian.Uint32(b[nonceSize:]))

	return &msgInitSARemote{nonce: nonce, spi: spi}, nil
}

// size returns the length of the serialized message in bytes.
//
// Only the first 4 of the 32 bytes reserved after the nonce hold the SPI;
// serialize leaves the rest zero and deserialize ignores them.
// NOTE(review): 32 looks like the SPI width in bits rather than bytes, but
// deserialize enforces this exact length, so changing it alters the wire
// format.
func (msg *msgInitSARemote) size() int {
	const spiField = 32 // SPI
	return nonceSize + spiField
}

// serialize encodes the message as the nonce followed by the big-endian SPI,
// zero-padded to size() bytes.
func (msg *msgInitSARemote) serialize() []byte {
	out := make([]byte, msg.size())

	nonceField, rest := out[:nonceSize], out[nonceSize:]
	copy(nonceField, msg.nonce)
	binary.BigEndian.PutUint32(rest, uint32(msg.spi))

	return out
}
